import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad

class Transformer(nn.Module):
    """Regress a per-sample K x K transform matrix from a point cloud.

    PointNet-style T-Net. It applies a shared per-point MLP (1x1
    convolutions with BatchNorm + ReLU), then a max-pool over the point
    dimension, then a fully connected head that outputs K*K values. The
    flattened identity matrix is added to the head's output, so a zero head
    output yields the identity transform.

    Args:
        num_points: Nominal number of points per cloud. Kept as ``self.N``
            for backward compatibility. Pooling always covers however many
            points the input actually has.
        K: Number of input channels per point, and the side length of the
            predicted square matrix.

    Input:  tensor of shape (batch, K, n_points).
    Output: tensor of shape (batch, K, K).
    """

    def __init__(self, num_points=2000, K=3):
        super().__init__()

        self.K = K
        self.N = num_points

        # A buffer follows the module through .to()/.cuda()/.double().
        # It is non-persistent so state_dict keys stay unchanged and old
        # checkpoints still load. The previous code forced it onto CUDA in
        # float64 at construction time, which crashed on CPU-only machines.
        self.register_buffer(
            "identity", torch.eye(K).reshape(-1), persistent=False)

        self.block1 = nn.Sequential(
            nn.Conv1d(K, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU())
        self.block2 = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU())
        self.block3 = nn.Sequential(
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU())

        # Head: global feature (1024) -> flattened K x K matrix.
        self.mlp = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, K * K))

    def forward(self, x):
        """Return a (batch, K, K) transform for input of shape (batch, K, n)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # Global max over the point dimension. The kernel is sized from the
        # input rather than from num_points, so every point is pooled and the
        # result is always (batch, 1024).
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        x = self.mlp(x)

        # Bias the prediction toward the identity matrix.
        x = x + self.identity
        return x.view(-1, self.K, self.K)

